import torch

# Demonstrate PyTorch broadcasting: the operand shapes are compared from the
# trailing dimension backwards, and every size-1 dim of ``b`` (dims 0 and 3)
# is virtually expanded to match ``a``, so the sum takes ``a``'s full shape.
lhs_shape = (4, 32, 14, 14)
rhs_shape = (1, 32, 14, 1)

# Draw ``a`` before ``b`` so the global RNG is consumed in the same order.
a = torch.randn(lhs_shape)
b = torch.randn(rhs_shape)

# Elementwise sum with broadcasting; equivalent to ``a + b``.
c = torch.add(a, b)
print(c.shape)
